# TRL တွင် GRPO ကို အကောင်အထည်ဖော်ခြင်း[[implementing-grpo-in-trl]]

ဒီစာမျက်နှာမှာ၊ Transformer Reinforcement Learning (TRL) library ကို အသုံးပြုပြီး Group Relative Policy Optimization (GRPO) ကို ဘယ်လိုအကောင်အထည်ဖော်ရမလဲဆိုတာ လေ့လာသွားမှာပါ။ ကျွန်တော်တို့ဟာ code ကို အနည်းဆုံးနဲ့ လက်တွေ့အကောင်အထည်ဖော်ခြင်းကို အဓိကထားမှာပါ။

GRPO ရဲ့ အဓိကသဘောတရားတွေကို TRL ရဲ့ GRPOTrainer မှာ ဘယ်လိုပါဝင်နေလဲဆိုတာကို လေ့လာသွားမှာဖြစ်ပြီး၊ တရားဝင် TRL documentation က snippets တွေကို လမ်းညွှန်အဖြစ် အသုံးပြုပါမယ်။

> [!TIP]
> ဒီအခန်းက TRL စတင်လေ့လာသူတွေအတွက် ရည်ရွယ်ပါတယ်။ သင် TRL ကို ကျွမ်းကျင်ပြီးသားဆိုရင်၊ GRPO ရဲ့ [Open R1 implementation](https://github.com/huggingface/open-r1/blob/main/src/open_r1/grpo.py) ကိုလည်း လေ့လာကြည့်နိုင်ပါတယ်။

ပထမဆုံးအနေနဲ့၊ GRPO algorithm ရဲ့ အရေးကြီးတဲ့ သဘောတရားအချို့ကို ပြန်လည်သတိရကြရအောင်။

-   **Group Formation**: model က prompt တစ်ခုစီအတွက် completions များစွာကို ထုတ်လုပ်ပါတယ်။
-   **Preference Learning**: model က completions အုပ်စုတွေကို နှိုင်းယှဉ်တဲ့ reward function ကနေ သင်ယူပါတယ်။
-   **Training Configuration**: model က training process ကို ထိန်းချုပ်ဖို့ configuration တစ်ခုကို အသုံးပြုပါတယ်။

GRPO ကို အကောင်အထည်ဖော်ဖို့ ကျွန်တော်တို့ ဘာတွေလုပ်ဖို့ လိုအပ်မလဲ။

-   prompts များ၏ dataset တစ်ခုကို သတ်မှတ်ပါ။
-   completions စာရင်းကို ယူပြီး rewards စာရင်းကို ပြန်ပေးမယ့် reward function တစ်ခုကို သတ်မှတ်ပါ။
-   training process ကို GRPOConfig တစ်ခုဖြင့် configure လုပ်ပါ။
-   GRPOTrainer ကို အသုံးပြုပြီး model ကို train လုပ်ပါ။

GRPO training ကို စတင်ဖို့အတွက် အနိမ့်ဆုံး ဥပမာတစ်ခုကတော့ အောက်ပါအတိုင်းပါ။

```python
from trl import GRPOTrainer, GRPOConfig
from datasets import load_dataset

# 1. သင့် dataset ကို load လုပ်ပါ
dataset = load_dataset("your_dataset", split="train")


# 2. ရိုးရှင်းသော reward function တစ်ခုကို သတ်မှတ်ပါ
def reward_func(completions, **kwargs):
    """ဥပမာ- ပိုရှည်သော completions များကို ဆုချပါ"""
    return [float(len(completion)) for completion in completions]


# 3. Training ကို Configure လုပ်ပါ
training_args = GRPOConfig(
    output_dir="output",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    logging_steps=10,
)

# 4. စတင်ပြီး train လုပ်ပါ
trainer = GRPOTrainer(
    model="your_model",  # ဥပမာ- "Qwen/Qwen2-0.5B-Instruct"
    args=training_args,
    train_dataset=dataset,
    reward_funcs=reward_func,
)
trainer.train()
```

## အဓိက အစိတ်အပိုင်းများ

### ၁။ Dataset Format

သင့် dataset တွင် model က တုံ့ပြန်မည့် prompts များ ပါဝင်သင့်ပါတယ်။ GRPO trainer က prompt တစ်ခုစီအတွက် completions များစွာကို ထုတ်လုပ်ပြီး ၎င်းတို့ကို နှိုင်းယှဉ်ဖို့ reward function ကို အသုံးပြုပါလိမ့်မယ်။

### ၂။ Reward Function

reward function ဟာ အရေးကြီးပါတယ်။ ဒါက model က ဘယ်လိုသင်ယူတယ်ဆိုတာကို ဆုံးဖြတ်ပါတယ်။ လက်တွေ့ဥပမာ နှစ်ခုကတော့ အောက်ပါအတိုင်းပါ။

```python
# ဥပမာ ၁- completion အရှည်ပေါ်မူတည်သော reward
def reward_length(completions, **kwargs):
    return [float(len(completion)) for completion in completions]


# ဥပမာ ၂- pattern ကို ကိုက်ညီမှုအပေါ်မူတည်သော reward
import re


def reward_format(completions, **kwargs):
    pattern = r"^<think>.*?</think><answer>.*?</answer>$"
    return [1.0 if re.match(pattern, c) else 0.0 for c in completions]
```

### ၃။ Training Configuration

`GRPOConfig` တွင် ထည့်သွင်းစဉ်းစားရမည့် အဓိက parameters များ -

```python
training_args = GRPOConfig(
    # မရှိမဖြစ်လိုအပ်သော parameters များ
    output_dir="output",
    num_train_epochs=3,
    num_generation=4,  # prompt တစ်ခုစီအတွက် ထုတ်လုပ်မည့် completions အရေအတွက်
    per_device_train_batch_size=4,  # generations အားလုံးကို device batch တစ်ခုတည်းမှာ ရယူလိုပါသည်။
    # ရွေးချယ်နိုင်သော သို့သော် အသုံးဝင်သော
    gradient_accumulation_steps=2,
    learning_rate=1e-5,
    logging_steps=10,
    # GRPO အတွက် သီးခြား (ရွေးချယ်နိုင်သော)
    use_vllm=True,  # generation ကို အရှိန်မြှင့်ရန်
)
```

`num_generation` parameter က GRPO အတွက် အထူးအရေးကြီးပါတယ်။ ဒါက group size ကို သတ်မှတ်ပါတယ်။ ဆိုလိုတာက model က prompt တစ်ခုစီအတွက် မတူညီတဲ့ completions ဘယ်နှစ်ခု ထုတ်လုပ်မလဲဆိုတာပါပဲ။ ဒါက အခြား RL methods တွေနဲ့ ကွာခြားတဲ့ အဓိကအချက်ပါ။

-   **အလွန်နည်းပါးလွန်းခြင်း (ဥပမာ- ၂-၃ ခု)**- အဓိပ္ပာယ်ရှိသော နှိုင်းယှဉ်မှုများအတွက် လုံလောက်သော ကွဲပြားမှု (diversity) ကို မပေးနိုင်ပါ။
-   **အကြံပြုထားသော (၄-၁၆ ခု)**- ကွဲပြားမှုနှင့် တွက်ချက်မှု ထိရောက်မှု (computational efficiency) အကြား ကောင်းမွန်သော ဟန်ချက်ကို ပေးပါသည်။
-   **ပိုမိုကြီးမားသော တန်ဖိုးများ**- သင်ယူမှုကို ပိုမိုကောင်းမွန်စေနိုင်သော်လည်း တွက်ချက်မှု ကုန်ကျစရိတ်ကို သိသိသာသာ တိုးမြှင့်စေသည်။

group size ကို သင့်ရဲ့ computational resources တွေနဲ့ task ရဲ့ ရှုပ်ထွေးမှုပေါ် မူတည်ပြီး ရွေးချယ်သင့်ပါတယ်။ ရိုးရှင်းတဲ့ tasks တွေအတွက်၊ သေးငယ်တဲ့ groups တွေ (၄-၈) က လုံလောက်နိုင်ပြီး၊ ပိုမိုရှုပ်ထွေးတဲ့ reasoning tasks တွေကတော့ ပိုမိုကြီးမားတဲ့ groups တွေ (၈-၁၆) ကနေ အကျိုးအမြတ်ရနိုင်ပါတယ်။

## အောင်မြင်မှုအတွက် အကြံပြုချက်များ

၁။ **Memory Management**: သင်၏ GPU memory အပေါ်မူတည်၍ `per_device_train_batch_size` နှင့် `gradient_accumulation_steps` ကို ချိန်ညှိပါ။
၂။ **Speed**: သင်၏ model ကို ထောက်ပံ့ပါက ပိုမိုမြန်ဆန်သော generation အတွက် `use_vllm=True` ကို ဖွင့်ပါ။
၃။ **Monitoring**: training လုပ်နေစဉ်အတွင်း log လုပ်ထားသော metrics များကို စောင့်ကြည့်ပါ။
    - `reward`: completions များ၏ ပျမ်းမျှ reward။
    - `reward_std`: reward groups များအတွင်းရှိ standard deviation။
    - `kl`: reference model မှ KL divergence။

## Reward Function ဒီဇိုင်း

DeepSeek R1 paper က သင်၏ GRPO implementation အတွက် လိုက်လျောညီထွေဖြစ်အောင် လုပ်ဆောင်နိုင်သော reward function ဒီဇိုင်းချခြင်း နည်းလမ်းများစွာကို ပြသထားပါတယ်။

### ၁။ Length-Based Rewards

အကောင်အထည်ဖော်ရအလွယ်ဆုံး rewards တွေထဲက တစ်ခုကတော့ length-based reward ပါပဲ။ ပိုရှည်တဲ့ completions တွေကို ဆုချနိုင်ပါတယ်။

```python
def reward_len(completions, **kwargs):
    ideal_length = 20
    return [-abs(ideal_length - len(completion)) for completion in completions]
```

ဒီ reward function က အလွန်တိုတောင်းလွန်းတဲ့ ဒါမှမဟုတ် အလွန်ရှည်လျားလွန်းတဲ့ completions တွေကို အပြစ်ပေးပါတယ်။ ဒါက model ကို ideal length ၂၀ tokens နဲ့ နီးစပ်တဲ့ completions တွေ ထုတ်လုပ်ဖို့ တိုက်တွန်းပါတယ်။

<!-- # TODO: update links when PR is merged -->

<iframe
    src="https://marimo.app/gh/huggingface/notebooks/main/e?entrypoint=course%2Fen%2Fchapter13%2Fgrpo_length.py&embed=true&show-chrome=false"
    title="Marimo Notebook"
    width="100%"
    height="800px"
    frameBorder="0"
    allow="clipboard-write"
></iframe>

## ၂။ Verifiable Tasks အတွက် Rule-Based Rewards

သင်္ချာ သို့မဟုတ် coding ကဲ့သို့ တိကျမှန်ကန်သော အဖြေများရှိသည့် tasks များအတွက်၊ rule-based reward functions များကို အကောင်အထည်ဖော်နိုင်ပါတယ်။

```python
def problem_reward(completions, answers, **kwargs):
    """Verifiable အဖြေများပါသော သင်္ချာပြဿနာများအတွက် reward function
    completions: အကဲဖြတ်ရန် completions စာရင်း
    answers: dataset မှ ပြဿနာများအတွက် အဖြေများစာရင်း
    """

    rewards = []
    for completion, correct_answer in zip(completions, answers):
        # completion မှ အဖြေကို ထုတ်ယူပါ
        try:
            # ဒါက ရိုးရှင်းတဲ့ ဥပမာတစ်ခုပါ - သင့်လျော်တဲ့ parsing လိုအပ်ပါလိမ့်မယ်
            answer = extract_final_answer(completion)
            # Binary reward: မှန်ရင် 1၊ မှားရင် 0
            reward = 1.0 if answer == correct_answer else 0.0
            rewards.append(reward)
        except:
            # အဖြေကို parse မလုပ်နိုင်ရင်၊ reward နည်းနည်း ပေးပါ
            rewards.append(0.0)

    return rewards
```

<!-- # TODO: update links when PR is merged -->

<iframe
    src="https://marimo.app/gh/huggingface/notebooks/main/e?entrypoint=course%2Fen%2Fchapter13%2Fgrpo_math.py&embed=true&show-chrome=false"
    title="Marimo Notebook"
    width="100%"
    height="800px"
    frameBorder="0"
    allow="clipboard-write"
></iframe>

## ၃။ Format-Based Rewards

DeepSeek R1 training မှာ အရေးကြီးခဲ့တဲ့ သင့်လျော်တဲ့ formatting ကိုလည်း ဆုချနိုင်ပါတယ်။

```python
def format_reward(completions, **kwargs):
    """လိုချင်သော format ကို လိုက်နာသော completions များကို ဆုချပါ"""
    # ဥပမာ- completion က think-then-answer format ကို လိုက်နာခြင်းရှိမရှိ စစ်ဆေးပါ
    pattern = r"<think>(.*?)</think>\s*<answer>(.*?)</answer>"

    rewards = []
    for completion in completions:
        match = re.search(pattern, completion, re.DOTALL)
        if match:
            # sections နှစ်ခုလုံးမှာ အဓိက အကြောင်းအရာများ ရှိမရှိ စစ်ဆေးပါ
            think_content = match.group(1).strip()
            answer_content = match.group(2).strip()

            if len(think_content) > 20 and len(answer_content) > 0:
                rewards.append(1.0)
            else:
                rewards.append(
                    0.5
                )  # မှန်ကန်သော format ဖြစ်သော်လည်း အကြောင်းအရာ နည်းပါးပါက partial reward
        else:
            rewards.append(0.0)  # format မမှန်က reward မပေးပါ

    return rewards
```

<!-- # TODO: update links when PR is merged -->

<iframe
    src="https://marimo.app/gh/huggingface/notebooks/main/e?entrypoint=course%2Fen%2Fchapter13%2Fgrpo_format.py&embed=true&show-chrome=false"
    title="Marimo Notebook"
    width="100%"
    height="800px"
    frameBorder="0"
    allow="clipboard-write"
></iframe>

ဒီဥပမာတွေက DeepSeek R1 training process ကနေ လှုံ့ဆော်မှုရယူပြီး မှန်ကန်မှု၊ formatting နဲ့ ပေါင်းစပ်ထားသော signals တွေကို အဓိကထားတဲ့ reward functions တွေကို ဘယ်လိုအကောင်အထည်ဖော်ရမယ်ဆိုတာကို ပြသထားပါတယ်။

## ဒါပါပဲ!

နောက်အပိုင်းမှာ၊ TRL မှာ GRPO ကို အကောင်အထည်ဖော်ဖို့ လေ့ကျင့်ခန်းတစ်ခုကို သင်လိုက်လုပ်ရပါလိမ့်မယ်။

---

## ဝေါဟာရ ရှင်းလင်းချက် (Glossary)

*   **GRPO (Group Relative Policy Optimization)**: Reinforcement Learning (RL) algorithm တစ်ခုဖြစ်ပြီး model က ထုတ်လုပ်လိုက်တဲ့ completions အုပ်စုတွေကို နှိုင်းယှဉ်ပြီး သင်ယူကာ model ရဲ့ policy ကို optimize လုပ်ပါတယ်။
*   **TRL (Transformer Reinforcement Learning) Library**: Hugging Face မှ ထုတ်လုပ်ထားသော library တစ်ခုဖြစ်ပြီး Transformer models များကို Reinforcement Learning techniques ဖြင့် fine-tune လုပ်ရန် ရည်ရွယ်သည်။
*   **Implementation**: သီအိုရီ သို့မဟုတ် algorithm တစ်ခုကို code အဖြစ် အကောင်အထည်ဖော်ခြင်း။
*   **GRPOTrainer**: TRL library မှ GRPO algorithm ကို အကောင်အထည်ဖော်သော Trainer class။
*   **TRL Documentation**: TRL library ၏ တရားဝင်မှတ်တမ်းများ။
*   **Open R1 Implementation**: GRPO algorithm ၏ open-source အကောင်အထည်ဖော်မှု။
*   **Group Formation**: model က prompt တစ်ခုစီအတွက် completions များစွာကို ထုတ်လုပ်ပြီး အုပ်စုဖွဲ့ခြင်း။
*   **Completions**: model က prompt တစ်ခုကို တုံ့ပြန်တဲ့အနေနဲ့ ထုတ်လုပ်ပေးတဲ့ စာသား သို့မဟုတ် sequence များ။
*   **Preference Learning**: reward function မှတဆင့် completions အုပ်စုများကို နှိုင်းယှဉ်ခြင်းဖြင့် model က သင်ယူသော လုပ်ငန်းစဉ်။
*   **Reward Function**: model ၏ output (completions) များကို အကဲဖြတ်ပြီး ဂဏန်းတန်ဖိုး (reward) တစ်ခုကို ပြန်ပေးသော function။ ၎င်းသည် model ကို သင်ယူရာတွင် လမ်းညွှန်ပေးသည်။
*   **Training Configuration**: training process အတွက် parameters များနှင့် settings များကို သတ်မှတ်ခြင်း။
*   **GRPOConfig**: TRL library မှ GRPO training အတွက် configuration များကို ထိန်းချုပ်သော class။
*   **Dataset of Prompts**: model က တုံ့ပြန်ရန်အတွက် အသုံးပြုမည့် prompts များပါဝင်သော dataset။
*   **`trl`**: Transformer Reinforcement Learning library။
*   **`GRPOTrainer`**: TRL မှ GRPO algorithm အတွက် Trainer class။
*   **`GRPOConfig`**: TRL မှ GRPO training အတွက် configuration class။
*   **`load_dataset`**: Hugging Face Datasets library မှ dataset များကို load လုပ်ရန် function။
*   **`output_dir`**: trained model နှင့် logs များကို သိမ်းဆည်းမည့် directory။
*   **`num_train_epochs`**: training လုပ်မည့် epochs အရေအတွက်။
*   **`per_device_train_batch_size`**: device တစ်ခုစီ (ဥပမာ- GPU) အတွက် batch size။
*   **`gradient_accumulation_steps`**: gradients များကို update မလုပ်မီ batches မည်မျှစုဆောင်းမည်ကို သတ်မှတ်ခြင်း။
*   **`logging_steps`**: training log များကို မည်သည့် step အရေအတွက်တိုင်းတွင် မှတ်တမ်းတင်မည်ကို သတ်မှတ်ခြင်း။
*   **`model` (argument in `GRPOTrainer`)**: အသုံးပြုမည့် base model ၏ identifier သို့မဟုတ် instance။
*   **`args` (argument in `GRPOTrainer`)**: training configuration arguments များ။
*   **`train_dataset`**: training အတွက် အသုံးပြုမည့် dataset။
*   **`reward_funcs`**: reward function (များ)။
*   **`trainer.train()`**: training process ကို စတင်ရန် method။
*   **Prompts**: model ကို တုံ့ပြန်စေလိုသော စာသား input များ။
*   **`re` Module**: Python ၏ regular expression module။
*   **`re.match()`**: string ၏ အစမှ pattern ကို ကိုက်ညီမှုရှိမရှိ စစ်ဆေးရန် function။
*   **`num_generation`**: prompt တစ်ခုစီအတွက် model က ထုတ်လုပ်မည့် completions အရေအတွက်။ ၎င်းသည် GRPO ၏ group size ဖြစ်သည်။
*   **RL Methods (Reinforcement Learning Methods)**: trial-and-error မှတစ်ဆင့် သင်ယူပြီး reward အများဆုံးရရှိရန် ကြိုးစားသော Machine Learning algorithms များ။
*   **Diversity**: ထုတ်လုပ်လိုက်သော completions များ၏ ကွဲပြားမှု။
*   **Computational Efficiency**: တွက်ချက်မှုအရင်းအမြစ်များကို မည်မျှထိရောက်စွာ အသုံးပြုသည်ကို ဆိုလိုသည်။
*   **Computational Cost**: တွက်ချက်မှု လုပ်ဆောင်ရန် လိုအပ်သော အချိန်နှင့် အရင်းအမြစ်များ။
*   **Reasoning Tasks**: အကြောင်းပြချက်၊ ဆင်ခြင်တုံတရား လိုအပ်သော လုပ်ငန်းများ။
*   **Memory Management**: ကွန်ပျူတာ၏ မှတ်ဉာဏ် (memory) အသုံးပြုမှုကို ထိန်းချုပ်ခြင်း။
*   **GPU Memory**: Graphics Processing Unit (GPU) တွင်ရှိသော မှတ်ဉာဏ်။
*   **`use_vllm=True`**: vLLM (a high-throughput inference engine) ကို အသုံးပြု၍ generation ကို အရှိန်မြှင့်ရန်။
*   **Logged Metrics**: training လုပ်နေစဉ်အတွင်း မှတ်တမ်းတင်ထားသော တိုင်းတာမှုများ။
*   **`reward` (metric)**: completions များ၏ ပျမ်းမျှ reward တန်ဖိုး။
*   **`reward_std` (metric)**: reward groups များအတွင်းရှိ rewards များ၏ standard deviation။
*   **`kl` (metric)**: KL divergence (Kullback-Leibler divergence) ကို ရည်ညွှန်းပြီး reference model မှ policy က မည်မျှကွာခြားသည်ကို တိုင်းတာသည်။
*   **DeepSeek R1 Paper**: DeepSeek R1 model နှင့် ၎င်း၏ training method များကို ဖော်ပြထားသော research paper။
*   **Length-Based Reward**: completion ၏ အရှည်ပေါ်မူတည်၍ ပေးသော reward။
*   **`ideal_length`**: completion အတွက် လိုချင်သော အရှည်။
*   **`abs()`**: ဂဏန်းတစ်ခု၏ absolute value (အနုတ်လက္ခဏာမပါသော တန်ဖိုး)။
*   **Verifiable Tasks**: အဖြေကို တိကျစွာ စစ်ဆေးအတည်ပြုနိုင်သော လုပ်ငန်းများ။
*   **Rule-Based Reward Functions**: သတ်မှတ်ထားသော စည်းမျဉ်းများ သို့မဟုတ် အခြေအနေများအပေါ် အခြေခံ၍ reward ပေးသော function များ။
*   **`extract_final_answer()`**: completion မှ နောက်ဆုံးအဖြေကို ထုတ်ယူရန် ဒီဇိုင်းထုတ်ထားသော function (ဥပမာတွင် ရိုးရှင်းထားသည်)။
*   **Binary Reward**: 0 သို့မဟုတ် 1 ကဲ့သို့သော တန်ဖိုးနှစ်ခုသာ ရှိသော reward (မှန်/မှား)။
*   **Parsing**: စာသားကို ခွဲခြမ်းစိတ်ဖြာပြီး အဓိပ္ပာယ်ဖော်ခြင်း။
*   **Format-Based Rewards**: completion ၏ formatting (ပုံစံချထားမှု) အပေါ်မူတည်၍ ပေးသော reward။
*   **`re.search()`**: string တစ်ခုအတွင်း pattern ကို ရှာဖွေရန် function။
*   **`re.DOTALL`**: regular expression flags တစ်ခုဖြစ်ပြီး `.` (dot) သည် newline character (`\n`) အပါအဝင် မည်သည့် character ကိုမဆို ကိုက်ညီစေသည်။
*   **`match.group(1)` / `match.group(2)`**: regular expression match object မှ သက်ဆိုင်ရာ capture group ၏ contents များကို ထုတ်ယူခြင်း။
*   **`strip()`**: string တစ်ခု၏ အစ သို့မဟုတ် အဆုံးရှိ whitespace များကို ဖယ်ရှားခြင်း။
*   **Partial Reward**: အပြည့်အဝ reward မဟုတ်ဘဲ တစ်စိတ်တစ်ပိုင်း reward။